import numpy as np
import torch
import copy
import time
import torch.nn as nn
from tqdm import tqdm
from fedlearn.models.models import choose_model
from torch.utils.data import DataLoader
from fedlearn.utils.metric import Metrics
from fedlearn.utils.model_utils import global_EOD, global_DP


class Server:
    """Base federated-learning server.

    Holds the global model, builds and samples clients, and evaluates
    accuracy / loss / fairness across them. ``options`` is the experiment
    configuration dict; the keys read are listed in ``__init__``.
    """

    # Fairness notions reported per ``options['fairness_type']``. In the
    # per-round report the first notion is summarised across clients by its
    # maximum and the second by its mean.
    _FAIR_NOTIONS = {
        'groupwise': ('DP', 'EO'),
        'subgroup': ('multi_DP', 'multi_EOP'),
    }

    def __init__(self, dataset, options, name=''):
        """Read configuration from ``options`` and build the global model.

        ``dataset`` and ``name`` are accepted for interface compatibility but
        are not used here; client data is attached later by ``setup_clients``.
        """
        # Basic parameters
        self.gpu = options['gpu']
        self.device = options['device']
        self.all_train_data_num = 0
        self.num_users = options['num_users']
        self.num_round = options['num_round']
        self.clients_per_round = options['clients_per_round']
        self.eval_round = options['eval_round']
        self.local_lr = options['local_lr']
        self.test_local = options['test_local']
        self.num_local_round = options['num_local_round']
        self.fair_measure = options['fairness_measure']
        self.print = options['print_result']
        self.client = []
        self.data_info = options['data_info']
        self.fairness_constraints = options['fairness_constraints']
        self.fair_metric = self.fairness_constraints['fairness_measure']
        self.options = options

        self.model = choose_model(options)
        if 'gpu' in options and (options['gpu'] is True):
            # NOTE(review): this uses cuda:0 when available and does not consult
            # self.device (options['device']); confirm that is intended.
            device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
            self.model = self.model.to(device)
        # Flat parameter vector of the global model, pushed to clients at eval time.
        self.latest_model = self.model.get_flat_params()
        self.local_model_dim = self.latest_model.shape[0]
        self.algo = options['algorithm']

    def setup_clients(self, Client, dataset, options):
        """Create one ``Client`` per user in the training split.

        Args:
            Client: client class, called as
                ``Client(user, train_data, val_data, test_data, options, model)``.
            dataset: ``(train_data, val_data, test_data)``. Each split is a dict
                with a ``'user_data'`` mapping keyed by user. ``train_data`` also
                provides ``'users'`` (iteration order) and ``'num_samples'``.
            options: experiment options; ``options['num_users']`` is re-read.

        Returns:
            List of clients in ``train_data['users']`` order. Each client gets its
            own deep copy of the server model.
        """
        self.train_data, self.val_data, self.test_data = dataset
        self.all_train_data_num = sum(self.train_data['num_samples'].values())
        self.num_users = options['num_users']

        return [
            Client(user,
                   self.train_data['user_data'][user],
                   self.val_data['user_data'][user],
                   self.test_data['user_data'][user],
                   options,
                   copy.deepcopy(self.model))
            for user in self.train_data['users']
        ]

    def select_clients(self, clients, seed=None):
        """Sample ``min(clients_per_round, len(clients))`` distinct clients uniformly.

        If ``seed`` is given (including 0), the global NumPy RNG is reseeded
        first, so the same seed always yields the same selection.
        """
        num_clients = min(self.clients_per_round, len(clients))
        if seed is not None:
            np.random.seed(seed)
        # Sample positions rather than the objects themselves. np.random.choice
        # would otherwise coerce ``clients`` into an ndarray, which breaks for
        # sequence-like client objects. The RNG draw is identical (a permutation
        # of range(len(clients)) truncated to ``num_clients``).
        picked = np.random.choice(len(clients), num_clients, replace=False)
        return [clients[i] for i in picked]

    def test(self, clients, current_round):
        """Evaluate ``clients``, print a report, and return aggregate statistics.

        Accuracy and loss are pooled over clients: the per-client ``acc`` and
        ``loss`` values from ``local_eval`` are summed and divided by the total
        sample count. That presumes they are per-client totals rather than
        means (TODO confirm against ``Client.local_eval``).

        Returns:
            dict with keys 'train_loss', 'train_acc', 'train_fairness',
            'test_acc', 'test_loss', 'test_fairness' and 'time' (wall-clock
            seconds spent in ``local_test``).
        """
        begin_time = time.time()
        client_stats, ids = self.local_test(clients)
        end_time = time.time()

        train_stats, test_stats = client_stats['train'], client_stats['test']
        train_total = sum(train_stats['num_samples'])
        test_total = sum(test_stats['num_samples'])

        model_stats = {
            'train_loss': sum(train_stats['losses']) / train_total,
            'train_acc': sum(train_stats['accs']) / train_total,
            'train_fairness': self.global_fairness(train_stats, 'train'),
            'test_acc': sum(test_stats['accs']) / test_total,
            'test_loss': sum(test_stats['losses']) / test_total,
            'test_fairness': self.global_fairness(test_stats, 'test'),
            'time': end_time - begin_time,
        }

        # NOTE: the report is printed unconditionally; options['print_result']
        # (self.print) is not consulted here.
        self.print_result(current_round, model_stats, client_stats, ids=ids)

        return model_stats

    def local_test(self, clients):
        """Run ``local_eval`` on every client and collect per-split statistics.

        Each client is loaded with its own ``local_params`` when ``test_local``
        is set, and otherwise with the latest global parameters.

        Returns:
            ``(client_stats, ids)``. ``client_stats[split][field]`` holds one
            entry per client, in ``clients`` order, for split in ('train',
            'test') and field in ('num_samples', 'accs', 'losses',
            'local_fair_measure', 'local_confusion'). ``ids`` lists each
            client's ``cid`` in the same order.
        """
        assert self.latest_model is not None

        # Output field -> key in the dict returned by Client.local_eval().
        fields = {
            'num_samples': 'num',
            'accs': 'acc',
            'losses': 'loss',
            'local_fair_measure': 'local_fair_measure',
            'local_confusion': 'local_confusion',
        }
        client_stats = {split: {field: [] for field in fields} for split in ('train', 'test')}

        for c in clients:
            c.set_params(c.local_params if self.test_local else self.latest_model)
            for split, test_dict in c.local_eval().items():
                for field, key in fields.items():
                    client_stats[split][field].append(test_dict[key])

        return client_stats, [c.cid for c in clients]

    def global_fairness(self, client_stats, datatype='test'):
        """Compute fairness gaps from confusion counts pooled over all clients.

        ``client_stats['local_confusion'][i]`` is indexed as
        ``[a][y][pred] -> count`` (sensitive attribute, true label, predicted
        label) for client ``i``.

        * 'groupwise' (binary A and Y, keyed 0/1):
          'DP' = |P(pred=1 | A=1) - P(pred=1 | A=0)|.
          'EO' = max over y in {0, 1} of
          |P(pred=1 | Y=y, A=1) - P(pred=1 | Y=y, A=0)|.
        * 'subgroup' (more than two labels):
          'multi_DP' = max over (a, y) of |P(pred=y | A=a) - P(pred=y)|,
          where P(pred=y) uses ``sum(num_samples)`` as its denominator.
          'multi_EOP' = max over (a, y) of
          |P(pred=y | Y=y, A=a) - P(pred=y | Y=y)|.

        A measure is only computed when 'DP' / 'EO' / 'EOP' occurs in
        ``options['fairness_measure']``; otherwise its entry stays 0. A rate
        over an empty group is NaN, and that NaN carries into the gap.

        Args:
            client_stats: one split ('train' or 'test') of ``local_test`` output.
            datatype: split name; unused, kept for interface compatibility.

        Returns:
            dict mapping measure name to its gap as a float (0 if not computed).

        Raises:
            ValueError: if ``options['fairness_type']`` is neither 'groupwise'
                nor 'subgroup'.
        """
        fairness_type = self.options['fairness_type']
        if fairness_type == 'groupwise':
            return self._groupwise_fairness(client_stats)
        if fairness_type == 'subgroup':
            return self._subgroup_fairness(client_stats)
        raise ValueError(f"Unsupported fairness_type: {fairness_type!r}")

    @staticmethod
    def _rate(numerator, denominator):
        """Return numerator / denominator as a float, or NaN when the denominator is 0.

        Accepts Python numbers and 0-d tensors alike, so integer confusion
        counts never reach torch functions that only accept tensors.
        """
        numerator, denominator = float(numerator), float(denominator)
        return numerator / denominator if denominator else float('nan')

    @staticmethod
    def _max_abs(values):
        """Return max(|v|) over a non-empty sequence of floats (NaN-propagating)."""
        return float(torch.max(torch.abs(torch.tensor(values, dtype=torch.float64))))

    def _groupwise_fairness(self, client_stats):
        """DP / EO gaps for binary sensitive attribute and binary label (keys 0/1)."""
        assert len(self.data_info['A_info']) == 2
        assert len(self.data_info['Y_info']) == 2
        fair_gap = {m: 0 for m in ["DP", "EO"]}
        measures = self.options['fairness_measure']
        confusions = client_stats['local_confusion']

        if "DP" in measures:
            pred1 = {0: 0, 1: 0}  # n_{a, pred=1}
            size = {0: 0, 1: 0}   # n_a
            for conf in confusions:
                for a in (0, 1):
                    for y_dict in conf[a].values():
                        pred1[a] += sum(v for p, v in y_dict.items() if p == 1)
                        size[a] += sum(y_dict.values())
            fair_gap['DP'] = self._max_abs(
                [self._rate(pred1[1], size[1]) - self._rate(pred1[0], size[0])])

        if "EO" in measures:
            gaps = []
            for y in (1, 0):
                # Positive predictions and group size among samples with label y.
                hit = {a: sum(conf[a][y][1] for conf in confusions) for a in (0, 1)}
                size = {a: sum(sum(conf[a][y].values()) for conf in confusions) for a in (0, 1)}
                gaps.append(self._rate(hit[1], size[1]) - self._rate(hit[0], size[0]))
            fair_gap['EO'] = self._max_abs(gaps)

        return fair_gap

    def _subgroup_fairness(self, client_stats):
        """multi_DP / multi_EOP gaps over all sensitive groups and labels."""
        assert len(self.data_info['A_info']) >= 2
        assert len(self.data_info['Y_info']) > 2
        fair_gap = {m: 0 for m in ["multi_DP", "multi_EOP"]}
        y_labels = list(self.options['data_info']['Y_info'].keys())
        a_labels = list(self.options['data_info']['A_info'].keys())
        measures = self.options['fairness_measure']
        confusions = client_stats['local_confusion']

        if "DP" in measures:
            total_samples = sum(client_stats['num_samples'])
            n_pred = {y: 0 for y in y_labels}                              # n_{pred=y}
            n_a_pred = {a: {y: 0 for y in y_labels} for a in a_labels}     # n_{a, pred=y}
            n_a = {a: 0 for a in a_labels}                                 # n_a
            for conf in confusions:
                for y in y_labels:
                    n_pred[y] += sum(v for by_y in conf.values() for row in by_y.values()
                                     for p, v in row.items() if p == y)
                for a in a_labels:
                    n_a[a] += sum(v for row in conf[a].values() for v in row.values())
                    for y in y_labels:
                        n_a_pred[a][y] += sum(v for row in conf[a].values()
                                              for p, v in row.items() if p == y)
            disparities = [
                self._rate(n_a_pred[a][y], n_a[a]) - self._rate(n_pred[y], total_samples)
                for a in a_labels for y in y_labels
            ]
            fair_gap['multi_DP'] = self._max_abs(disparities)

        if "EOP" in measures:
            n_a_y = {a: {y: 0 for y in y_labels} for a in a_labels}      # n_{a, y}
            n_a_y_hit = {a: {y: 0 for y in y_labels} for a in a_labels}  # n_{a, y, pred=y}
            for conf in confusions:
                for a in a_labels:
                    for y in y_labels:
                        row = conf[a][y]
                        n_a_y[a][y] += sum(row[p] for p in y_labels)
                        n_a_y_hit[a][y] += row[y]
            n_y = {y: sum(n_a_y[a][y] for a in a_labels) for y in y_labels}
            n_y_hit = {y: sum(n_a_y_hit[a][y] for a in a_labels) for y in y_labels}
            disparities = [
                self._rate(n_a_y_hit[a][y], n_a_y[a][y]) - self._rate(n_y_hit[y], n_y[y])
                for a in a_labels for y in y_labels
            ]
            fair_gap['multi_EOP'] = self._max_abs(disparities)

        return fair_gap

    def print_result(self, current_round, model_stats, client_stats, ids=None):
        """Print the round summary and per-client metrics via ``tqdm.write``.

        Per-client train rows are printed every round. Per-client test rows and
        a cross-client fairness summary are printed when ``current_round`` is a
        multiple of ``eval_round``. For an unrecognised ``fairness_type`` only
        the round header is printed.

        Args:
            current_round: round number shown in the header (int).
            model_stats: dict returned by ``test``.
            client_stats: per-client statistics from ``local_test``.
            ids: optional client ids aligned with ``client_stats`` and used as
                row labels. When omitted, rows are labelled positionally from
                ``train_data['users']`` / ``test_data['users']``. That is only
                correct if every client was evaluated, in that order.
        """
        tqdm.write("\n >>> Round: {: >4d} / Acc: {:.3%} / Loss: {:.4f} / Time: {:.2f}s /  Fairness: {} ".format(
            current_round, model_stats['train_acc'], model_stats['train_loss'], model_stats['time'],
            model_stats['train_fairness']))

        notions = self._FAIR_NOTIONS.get(self.options['fairness_type'])
        if notions is None:
            return
        first, second = notions

        def label(c, split_data):
            return ids[c] if ids is not None else split_data['users'][c]

        train_stats = client_stats['train']
        for c, (acc, loss, num, fair) in enumerate(zip(
                train_stats['accs'], train_stats['losses'],
                train_stats['num_samples'], train_stats['local_fair_measure'])):
            tqdm.write(
                f'>>> Client: {label(c, self.train_data)} / Acc: {acc / num:.3%} / Loss: {loss / num:.4f}'
                f' / Fair_{first}(train):{fair[first]:.4f} / Fair_{second}(train):{fair[second]:.4f}')
        tqdm.write('=' * 102 + "\n")

        if current_round % self.eval_round != 0:
            return

        tqdm.write("\n = Test = round: {} / acc: {:.3%} / loss: {:.4f} / Fairness: {}  ".format(
            current_round, model_stats['test_acc'], model_stats['test_loss'], model_stats['test_fairness']))
        test_stats = client_stats['test']
        for c, (acc, num, fair) in enumerate(zip(
                test_stats['accs'], test_stats['num_samples'], test_stats['local_fair_measure'])):
            tqdm.write(
                f'=== Test  Client: {label(c, self.test_data)} / Acc: {acc / num:.3%}'
                f' / Fair_{first}(test):{fair[first]:.4f} / Fair_{second}(test):{fair[second]:.4f}')

        # Cross-client summary: worst case for the first notion, mean for the second.
        local_fair = {n: [fair[n] for fair in test_stats['local_fair_measure']] for n in notions}
        M_fair = {
            first: max(local_fair[first]),
            second: sum(local_fair[second]) / len(local_fair[second]),
        }
        tqdm.write(f"=== Test local fair: {M_fair}")